from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor,AutoProcessor,MllamaForConditionalGeneration,LlavaForConditionalGeneration
import numpy as np
import torch.nn as nn
import torchvision
import torch
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm
import json
from matplotlib.pyplot import figure
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import random
from PIL import Image


def set_deterministic(seed):
    """Seed the PyTorch, NumPy and Python random number generators.

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (
        torch.manual_seed,           # PyTorch default generator
        np.random.seed,              # NumPy global generator
        random.seed,                 # Python's `random` module
        torch.cuda.manual_seed_all,  # every CUDA device generator
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Drop any cached CUDA memory before the model is loaded.
torch.cuda.empty_cache()

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Side length (in pixels) that every image is resized to.
size = 512

# The train and validation pipelines defined here are identical:
# resize to a size x size square, then convert to a tensor.
train_transform = transforms.Compose(
    [transforms.Resize((size, size)), transforms.ToTensor()]
)
val_transform = transforms.Compose(
    [transforms.Resize((size, size)), transforms.ToTensor()]
)

# Seed all random number generators before any data is loaded.
set_deterministic(0)

import re
def getdata(task, batch_size = 1):
    """Build the data loaders for *task*.

    Args:
        task: Name of the dataset to load. Only ``"ImageNet"`` is supported.
        batch_size: Number of images per batch for both loaders.

    Returns:
        A ``(train_loader, validation_loader, num_classes)`` tuple.

    Raises:
        ValueError: If *task* is not a supported dataset name.
    """
    # Fail early with a clear message instead of hitting an unbound local below.
    if task != "ImageNet":
        raise ValueError(
            f"Unsupported task {task!r}; only 'ImageNet' is available"
        )

    data_dir = './dataset'
    num_classes = 1000

    # The training pipeline adds random flips before the resize used for validation.
    train_augmentation = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])

    # NOTE(review): the "train" dataset is also read from the "val" split, exactly
    # as before -- confirm whether split="train" was intended.
    train_dataset = datasets.ImageNet(
        root=data_dir, split="val", transform=train_augmentation,
    )
    validation_dataset = datasets.ImageNet(
        root=data_dir, split="val", transform=val_transform,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
    )
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=batch_size,
    )
    return train_loader, validation_loader, num_classes


# Prompts sent to the model for every image. Each one asks for a complexity score
# between 1 and 9 in the form "Score: ? out of 9".
SCORING_PROMPTS = [
    "First answer the following three questions: What's in the image? Are there text or numbers in the image? Are there faces in the image? Then, based on the answers, determine the images complexity based on the following factors:\n1. Number of distinct objects\n2. Color variance\n3. Texture complexity\n4. Foreground and background\n5. Symmetry and repetition\n6. Human perception factors, like the presence of human faces or text\nYou will be given the caption, whether there are text or numbers, and whether there are faces in the image. Assign a complexity score such that a higher number means the image is more complex. Note that text and facial details are intrinsically complex because they are crucial to human perception. Here are some examples for scoring:\n\
- Score 1: A plane in a sky\n- Score 2: A t-shirt with a emoji on it\n- Score 3: A dog lying on the grass\n- Score 4: A woman skiing in the snow\n- Score 5: Two kids walking on the beach\n- Score 6: A dinning table full of food\n\
- Score 7: A close-up shot of a old man\n- Score 8: Many people gathering in the stadium\n- Score 9: Newspapers or graphs with text and numbers\nRespond with \"Score: ? out of 9\", where \"?\" is a number between 1 and 9. Then provide explanations."
]


def _parse_score(generated_text):
    """Return the first digit of the model's answer as an int, or 0 if there is none.

    Args:
        generated_text: Decoded model output, optionally preceded by the echoed
            prompt and an "ASSISTANT:" marker.

    Returns:
        The first digit found after the first "Score" in the answer (or the first
        digit of the answer when "Score" is absent); 0 when no digit is present.
    """
    # When the decoded text contains the prompt, keep only what follows the
    # assistant marker so the "Score N" examples in the prompt are not picked up.
    # A missing marker (or one without a trailing space) no longer raises.
    _, marker, answer = generated_text.partition("ASSISTANT:")
    if not marker:
        answer = generated_text

    # Skip ahead to just after the first "Score" when the model used that format.
    score_pos = answer.find("Score")
    if score_pos != -1:
        answer = answer[score_pos + len("Score"):]

    # `\d` only matches decimal digits, which `int` always accepts.
    digit = re.search(r"\d", answer)
    return int(digit.group()) if digit else 0


for model_id in ["llava-hf/llava-1.5-7b-hf"]:
    # Load the model and its matching processor.
    if model_id == "Salesforce/instructblip-vicuna-7b":
        model = InstructBlipForConditionalGeneration.from_pretrained(model_id)
        processor = InstructBlipProcessor.from_pretrained(model_id)
    elif model_id == "llava-hf/llava-1.5-7b-hf":
        model = LlavaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
        processor = AutoProcessor.from_pretrained(model_id)
    else:
        raise ValueError(f"Unsupported model id {model_id!r}")

    # Use the module-level device so the script also starts without a GPU.
    model.to(device).eval()
    # Only the LLaVA checkpoints go through the processor's chat template.
    is_chat = "llava" in model_id

    for task in ["ImageNet"]:
        # Maps the running image index to its score (stored as a string).
        cap = {}
        set_deterministic(0)

        _, validation_loader, _ = getdata(task, batch_size=1)
        num_im = 0

        for epoch in range(1):
            with tqdm(validation_loader, unit="batch") as tepoch:
                tepoch.set_description('Epoch '+str(epoch+1))

                for ori_images, _ in tepoch:
                    # Convert the loader's float images to uint8 before they are
                    # handed to the processor.
                    ori_images = torch.clamp(ori_images, 0, 1)
                    ori_images = (ori_images * 255).to(torch.uint8)

                    for txtprompt in SCORING_PROMPTS:
                        if is_chat:
                            messages = [
                                {"role": "user", "content": [
                                    {"type": "text", "text": txtprompt},
                                    {"type": "image"},
                                ]}
                            ]
                            txtprompt = processor.apply_chat_template(
                                messages, add_generation_prompt=True
                            )

                        inputs = processor(
                            images=ori_images, text=txtprompt, return_tensors="pt"
                        )
                        if "blip" in model_id:
                            # Convert any non-tensor entry to a tensor, then move
                            # each entry to the device one by one.
                            for key, value in inputs.items():
                                if not torch.is_tensor(value):
                                    value = torch.tensor(value)
                                inputs[key] = value.to(device)
                        else:
                            inputs = inputs.to(device)

                        outputs = model.generate(
                            **inputs,
                            max_new_tokens=512,
                            pad_token_id=model.config.eos_token_id
                        )
                        generated_text = processor.batch_decode(
                            outputs, skip_special_tokens=True
                        )[0].strip()

                        cap[num_im] = str(_parse_score(generated_text))
                        num_im += 1

                        # Drop the per-prompt tensors; `ori_images` is kept because
                        # it is still needed if another prompt follows.
                        del inputs, outputs

        # The file name uses the part of the model id before the "/".
        save_filename = task + "_" + model_id.split("/")[0]+ "_score.json"
        with open(save_filename, "w") as f:
            print("*** saving to", save_filename)
            json.dump(cap, f)